import os 
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F

# Run on the current CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class DECLoss(object):
    """Local distance-preservation ("isometry") loss plus a center-push loss.

    ``update`` caches, for every sample, the distances to its ``k`` nearest
    neighbours. Neighbourhoods are computed within the sample's predicted
    cluster, using the output-space kNN graph. The distances are measured both
    in input space and in output space. ``Loss_alpha`` compares those cached
    distances. It also compares the pairwise distance matrices of the cluster
    centers in the two spaces.

    Args:
        param: dict with keys
            ``'k_knn'`` (neighbours per sample),
            ``'push_scale'`` (factor applied to input-space center distances),
            ``'iso_type'`` (``'dis'`` or ``'kl'``),
            ``'data_num'`` (rows of the cached distance tables; presumably the
            dataset size, so confirm against the caller) and
            ``'n_clusters'``.
    """

    def __init__(self, param):
        self.k = param['k_knn']
        self.push_scale = param['push_scale']
        self.iso_type = param['iso_type']
        self.data_num = param['data_num']
        self.n_clusters = param['n_clusters']

        # Placeholders until ``update`` replaces them with (n_samples, k) distances.
        self.data_dis = torch.zeros((self.data_num, self.k), device=device)
        self.out_dis = torch.zeros((self.data_num, self.k), device=device)

    def t_distribute(self, dis, alpha=1.0):
        """Return the row-normalised Student-t kernel of a distance matrix.

        Each entry becomes ``(1 + dis**2 / alpha) ** -((alpha + 1) / 2)``.
        Every row is then divided by its sum, so each row sums to 1.
        """
        numerator = 1.0 / (1.0 + (dis ** 2 / alpha))
        numerator = numerator ** (float(alpha + 1.0) / 2)
        return numerator / torch.sum(numerator, 1, keepdim=True)

    def kNNGraph(self, data):
        """Return pairwise Euclidean distances and a k-nearest-neighbour mask.

        Args:
            data: 2-D tensor of shape (n_samples, n_features). It is moved
                to ``device``.

        Returns:
            d: (n_samples, n_samples) distance matrix. Squared distances are
                clamped at 1e-8 before the square root, so the diagonal is
                about 1e-4 rather than 0.
            kNN_mask: boolean (n_samples, n_samples) tensor. In each row it
                marks the ``k`` entries that follow the smallest one in sorted
                order (the smallest is normally the sample itself). Rows get
                fewer than ``k`` marks when ``n_samples <= k``.
        """
        n_samples = data.shape[0]
        x = data.to(device)

        # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>
        sq_norms = torch.pow(x, 2).sum(1, keepdim=True)
        dist = sq_norms.expand(n_samples, n_samples) + sq_norms.expand(n_samples, n_samples).t()
        # In place: dist = 1 * dist + (-2) * (x @ x.T). Keyword form, because the
        # positional (beta, alpha, mat1, mat2) overload is deprecated.
        dist.addmm_(x, x.t(), beta=1, alpha=-2)
        d = dist.clamp(min=1e-8).sqrt()

        # Skip sorted column 0 (normally the sample itself) and keep the next k.
        _, indices = torch.sort(d, dim=1)
        kNN_mask = torch.zeros((n_samples, n_samples), device=device)
        kNN_mask.scatter_(1, indices[:, 1:self.k + 1], 1)

        return d, kNN_mask.bool()

    # Imposing local isometry within each manifold
    def update(self, data, out, y_pred):
        """Cache per-sample kNN distances in input and output space.

        A kNN graph is built separately for every predicted cluster. The
        neighbour pairs come from the output-space graph and are read from
        both the input-space and the output-space distance matrices. Rows are
        restored to the original sample order and divided by
        sqrt(feature dimension) of the respective space.

        Args:
            data: input-space samples, 2-D (n_samples, d_in).
            out: output-space samples, 2-D (n_samples, d_out), row-aligned with
                ``data``.
            y_pred: 1-D cluster labels in ``range(self.n_clusters)``. It must
                work with ``np.flatnonzero`` and as a boolean row mask on
                ``data``/``out`` (e.g. a NumPy array; TODO confirm the caller's
                type).

        Raises:
            ValueError: if a non-empty cluster has ``k`` or fewer members. Such
                a cluster cannot give each of its samples ``k`` neighbours.
        """
        index_parts, data_parts, out_parts = [], [], []
        for i in range(self.n_clusters):
            member = y_pred == i
            index_list = torch.as_tensor(np.flatnonzero(member)).to(device)
            count = index_list.numel()
            if 0 < count <= self.k:
                raise ValueError(
                    f"cluster {i} has {count} member(s); every non-empty cluster "
                    f"needs more than k={self.k} samples"
                )

            data_dis, _ = self.kNNGraph(data[member])
            out_dis, out_knn = self.kNNGraph(out[member])
            # Both tables use the output-space neighbourhoods, so their rows
            # refer to the same neighbour pairs.
            data_parts.append(data_dis[out_knn].view(-1, self.k))
            out_parts.append(out_dis[out_knn].view(-1, self.k))
            index_parts.append(index_list)

        # Rows are grouped by cluster; sorting the sample indices restores the
        # original sample order.
        _, order = torch.sort(torch.cat(index_parts, 0))
        data_dis_masks = torch.index_select(torch.cat(data_parts, 0), 0, order)
        out_dis_masks = torch.index_select(torch.cat(out_parts, 0), 0, order)

        # Presumably scaled by sqrt(dim) so that spaces of different width
        # give comparable distances.
        self.data_dis = data_dis_masks / data.shape[1] ** 0.5
        self.out_dis = out_dis_masks / out.shape[1] ** 0.5

    def Loss_alpha(self, data, out, centers_data, centers_out, idx):
        """Compute the center-push loss and the local-isometry loss.

        Args:
            data, out, idx: unused here; kept for call compatibility.
            centers_data: 2-D cluster centers in input space.
            centers_out: 2-D cluster centers in output space.

        Returns:
            loss_push: Frobenius norm of (output-space center distances minus
                ``push_scale`` times input-space center distances), divided by
                ``n_clusters``.
            loss_iso: for ``'dis'``, the Frobenius norm of the difference
                between the cached input- and output-space neighbour distances;
                for ``'kl'``, ``F.kl_div`` between their row-normalised
                Student-t kernels (alpha 10 for input, 1 for output).
            num: for ``'dis'``, the number of neighbour distances whose absolute
                difference is below 0.01 (a float tensor); for ``'kl'``,
                ``data_num * k``.

        Raises:
            ValueError: if ``self.iso_type`` is neither ``'dis'`` nor ``'kl'``.
        """
        centers_data_dis, _ = self.kNNGraph(centers_data)
        centers_out_dis, _ = self.kNNGraph(centers_out)
        loss_push = torch.norm(centers_out_dis - centers_data_dis * self.push_scale) / self.n_clusters

        if self.iso_type == 'dis':
            error = self.data_dis - self.out_dis
            loss_iso = torch.norm(error)
            # Cast to the error's dtype so the count keeps its float type.
            num = (torch.abs(error) < 0.01).to(error.dtype).sum()
        elif self.iso_type == 'kl':
            data_dis_mask = self.t_distribute(self.data_dis, alpha=10).view(-1)
            out_dis_mask = self.t_distribute(self.out_dis, alpha=1).view(-1)
            loss_iso = F.kl_div(out_dis_mask.log(), data_dis_mask)
            num = self.data_num * self.k
        else:
            raise ValueError(f"unknown iso_type {self.iso_type!r}; expected 'dis' or 'kl'")

        return loss_push, loss_iso, num
    